#!/usr/bin/env python

import argparse
import os

import numpy as np
import pandas as pd


def build_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for the T3 low-mass size-slope sensitivity check.

    Options, in registration (and ``--help``) order: the input
    ``stack_health.csv`` path, the mass-bin / size-bin / amplitude column
    names, the ``--claimable-only`` row filter, and the output CSV path.
    """
    summary = (
        "T3 low-mass sensitivity: re-fit A_theta vs size with and without "
        "the smallest size bin, using stack_health.csv."
    )
    parser = argparse.ArgumentParser(description=summary)

    # (flag, default, help) for the input path and the three column overrides.
    string_options = [
        (
            "--stack-health",
            "outputs/stack_health.csv",
            "Path to stack_health.csv (default: outputs/stack_health.csv)",
        ),
        (
            "--mass-col",
            "Mstar_bin",
            "Mass-bin column name (default: Mstar_bin).",
        ),
        (
            "--size-col",
            "R_G_bin",
            "Size-bin column name (default: R_G_bin).",
        ),
        (
            "--amp-col",
            "A_theta",
            "Amplitude column name (default: A_theta).",
        ),
    ]
    for flag, default_value, help_text in string_options:
        parser.add_argument(flag, default=default_value, help=help_text)

    parser.add_argument(
        "--claimable-only",
        action="store_true",
        help="If set, only use rows with claimable==True (if column exists).",
    )
    parser.add_argument(
        "--out-csv",
        default="outputs/t3_lowmass_sensitivity_slopes.csv",
        help=(
            "Output CSV with slopes/CI "
            "(default: outputs/t3_lowmass_sensitivity_slopes.csv)."
        ),
    )
    return parser


def parse_RG_mid(label: object) -> float:
    """
    Parse a size-bin label into the numeric midpoint of its range.

    Accepts:
      * range labels such as '1.5–3.0' (en dash), '1.5—3.0' (em dash),
        '3-5' or '[3,5)';
      * bare numbers such as '4', '(4)', '-3' or '1e-3';
      * real numeric scalars (Python int/float or NumPy integer/floating,
        e.g. a column pandas already parsed as numbers), returned as float.

    Returns NaN on failure (including bool input and non-string objects).
    """
    # bool is a subclass of int but is never a meaningful bin value.
    if isinstance(label, (bool, np.bool_)):
        return np.nan
    # Already-numeric cells: the value itself is taken as the midpoint.
    if isinstance(label, (int, float, np.integer, np.floating)):
        return float(label)
    if not isinstance(label, str):
        return np.nan
    s = label.strip()
    if not s:
        return np.nan
    # Strip interval brackets/parentheses, e.g. '[3,5)' -> '3,5'.
    for ch in "[]()":
        s = s.replace(ch, "")
    # Normalize en dash, em dash and Unicode minus sign to an ASCII hyphen.
    for dash in ("–", "—", "−"):
        s = s.replace(dash, "-")
    # Try a bare number first: splitting on '-' would otherwise break
    # negative values and exponents such as '1e-3'.
    try:
        return float(s)
    except ValueError:
        pass
    # Range label: comma-separated if a comma is present, else hyphen-separated.
    parts = s.split(",") if "," in s else s.split("-")
    if len(parts) != 2:
        return np.nan
    try:
        lo, hi = float(parts[0]), float(parts[1])
    except ValueError:
        return np.nan
    return 0.5 * (lo + hi)


def linreg_with_ci(x: np.ndarray, y: np.ndarray, *, nsigma: float = 1.0):
    """
    Ordinary least-squares fit of y = a + b x with a symmetric CI on the slope.

    Pairs where x or y is non-finite (NaN/inf) are dropped before fitting.
    The slope interval is ``slope ± nsigma * se_slope``, where ``se_slope``
    uses the residual variance with n - 2 degrees of freedom. The default
    ``nsigma=1`` gives the approximate 68% (1-sigma) interval; this is a
    normal approximation, so for small n it is narrower than a Student-t
    interval would be.

    Parameters
    ----------
    x, y : array-like
        Predictor and response values of equal length.
    nsigma : float, keyword-only
        Half-width of the slope interval in units of the slope standard error.

    Returns
    -------
    dict with keys slope, intercept, slope_ci_lo, slope_ci_hi, n, r2.
        ``n`` counts the finite pairs actually used. All other fields are NaN
        when fewer than two pairs remain or x has zero spread. With exactly
        two pairs only slope/intercept are set (no residual degrees of
        freedom). ``r2`` is NaN when y is constant.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    mask = np.isfinite(x) & np.isfinite(y)
    x = x[mask]
    y = y[mask]
    n = len(x)

    nan_fit = {
        "slope": np.nan,
        "intercept": np.nan,
        "slope_ci_lo": np.nan,
        "slope_ci_hi": np.nan,
        "n": n,
        "r2": np.nan,
    }
    if n < 2:
        return nan_fit

    x_mean = x.mean()
    y_mean = y.mean()
    dx = x - x_mean
    dy = y - y_mean
    Sxx = np.sum(dx * dx)
    # Exact degeneracy test: a relative-tolerance check (np.allclose) would
    # wrongly reject distinct x values of large magnitude (e.g. 1e6, 1e6+1).
    if not Sxx > 0:
        return nan_fit
    slope = np.sum(dx * dy) / Sxx
    intercept = y_mean - slope * x_mean

    if n == 2:
        # Two points define the line exactly; no DOF left to estimate variance.
        return {**nan_fit, "slope": slope, "intercept": intercept}

    residuals = y - (slope * x + intercept)
    ssr = np.sum(residuals ** 2)
    se_slope = np.sqrt(ssr / (n - 2) / Sxx)
    half_width = nsigma * se_slope
    sst = np.sum(dy * dy)
    r2 = 1.0 - ssr / sst if sst > 0 else np.nan

    return {
        "slope": slope,
        "intercept": intercept,
        "slope_ci_lo": slope - half_width,
        "slope_ci_hi": slope + half_width,
        "n": n,
        "r2": r2,
    }


def _require_columns(df: pd.DataFrame, required) -> None:
    """Exit with an error naming the first (role, column) pair missing from ``df``."""
    for role, col in required:
        if col not in df.columns:
            raise SystemExit(
                f"ERROR: {role} column '{col}' not found in stack_health.csv."
            )


def _order_mass_bins(labels) -> list:
    """Order mass-bin labels from lowest to highest mass.

    When every label parses (via ``parse_RG_mid``) to a finite midpoint, the
    labels are ranked by that midpoint, so '9.5-10.0' precedes '10.0-10.5'
    (a plain string sort would invert them). Otherwise the natural sort order
    is used, falling back to sorting by ``str`` for non-comparable mixed types.
    """
    labels = list(labels)
    mids = [parse_RG_mid(lab) for lab in labels]
    if labels and all(np.isfinite(v) for v in mids):
        order = sorted(range(len(labels)), key=mids.__getitem__)
        return [labels[i] for i in order]
    try:
        return sorted(labels)
    except TypeError:
        return sorted(labels, key=str)


def _drop_smallest_size_fit(sub: pd.DataFrame, amp_col: str):
    """Fit amplitude vs size midpoint after removing the smallest size bin.

    Every row whose parsed midpoint equals the minimum finite midpoint is
    dropped (rows with NaN midpoints drop out too). Returns the
    ``linreg_with_ci`` result dict, or None if no row has a finite midpoint.
    """
    rg = sub["RG_mid_bin_local"].to_numpy(dtype=float)
    finite = np.isfinite(rg)
    if not finite.any():
        return None
    keep = rg > rg[finite].min()
    return linreg_with_ci(rg[keep], sub[amp_col].to_numpy()[keep])


def main():
    """Re-fit amplitude vs size per mass bin, plus a low-mass sensitivity variant.

    Reads ``--stack-health``. Keeps rows with ``win_nbins > 0`` and a non-null
    amplitude, and optionally only rows with ``claimable == True``. Parses
    size-bin labels into numeric midpoints and fits amplitude vs midpoint for
    every mass bin. For the lowest mass bin only, a second fit drops the
    smallest size bin. Results are written to ``--out-csv`` and printed.

    Raises:
        SystemExit: if the input CSV or a required column is missing.
    """
    args = build_parser().parse_args()

    if not os.path.exists(args.stack_health):
        raise SystemExit(
            f"ERROR: stack health CSV not found at {args.stack_health}.\n"
            f"Run scripts/make_stack_health_table.py first."
        )

    df = pd.read_csv(args.stack_health)

    # 'win_nbins' is needed by the window filter below; validating it here
    # gives a clear error instead of a bare KeyError.
    _require_columns(
        df,
        [
            ("mass", args.mass_col),
            ("size", args.size_col),
            ("amplitude", args.amp_col),
            ("window", "win_nbins"),
        ],
    )

    # Optional claimable filter
    if args.claimable_only:
        if "claimable" in df.columns:
            df = df[df["claimable"].eq(True)].copy()
        else:
            print("[warn] --claimable-only ignored: no 'claimable' column found.")

    # Require a declared window (non-numeric entries count as no window)
    # and a non-null amplitude.
    win_nbins = pd.to_numeric(df["win_nbins"], errors="coerce")
    df = df[(win_nbins > 0) & df[args.amp_col].notna()].copy()

    # Parse numeric size midpoints
    df["RG_mid_bin_local"] = df[args.size_col].apply(parse_RG_mid)

    # NaN labels never match a row via == and would make the ordering (and
    # hence the choice of the lowest mass bin) unreliable, so drop them.
    mass_bins_sorted = _order_mass_bins(df[args.mass_col].dropna().unique())

    records = []
    for rank, m in enumerate(mass_bins_sorted):
        sub = df[df[args.mass_col] == m]
        if sub.empty:
            continue

        # Baseline: all size bins with windows.
        reg_all = linreg_with_ci(
            sub["RG_mid_bin_local"].to_numpy(), sub[args.amp_col].to_numpy()
        )
        records.append({"variant": "baseline_all_sizes", "Mstar_bin": m, **reg_all})

        # Sensitivity variant, lowest mass bin only: drop the smallest size bin.
        if rank == 0:
            reg_cut = _drop_smallest_size_fit(sub, args.amp_col)
            if reg_cut is not None:
                records.append(
                    {
                        "variant": "lowmass_drop_smallest_size",
                        "Mstar_bin": m,
                        **reg_cut,
                    }
                )

    # Explicit columns keep a header row even when no fits were produced.
    # The output key is always 'Mstar_bin', regardless of --mass-col.
    out_df = pd.DataFrame.from_records(
        records,
        columns=[
            "variant",
            "Mstar_bin",
            "slope",
            "intercept",
            "slope_ci_lo",
            "slope_ci_hi",
            "n",
            "r2",
        ],
    )
    out_dir = os.path.dirname(args.out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_df.to_csv(args.out_csv, index=False)

    print("\n=== T3 size–amplitude slopes (sensitivity) ===\n")
    if not out_df.empty:
        print(out_df.to_string(index=False))
    else:
        print("[info] No regression results to report (check filters).")


if __name__ == "__main__":
    main()
